#include "faiss_index_factory.h"

#include <regex>
#include "IndexIVFPQ.h"
#include "IndexHNSW.h"
#include "index_io.h"

#include "IndexIVFFlat.h"
#include "log.h"

VET_INDEX_NAMESPACE_BEGIN

IndexFactory::IndexFactory()
    : p_app_field_table_(const_cast<AppFieldContext::AppFieldTable*>(
                &(ConfigManager::Instance()->GetAppFieldTable()->o_app_field_table_) ))
    , index_map_()
    , idx_num_map_()
    , train_idx_map_()
{

}

/**
 * Destroys the factory; members are destroyed in the usual way and nothing
 * else is released explicitly.
 *
 * NOTE(review): init_single_index() stores wrappers created with `new`, and
 * init_index_for_train() stores indexes returned by faiss::index_factory.
 * Neither is deleted here. Confirm whether index_map_ / train_idx_map_ own
 * these pointers (e.g. through smart pointers) or whether they leak on
 * destruction.
 */
IndexFactory::~IndexFactory() = default;

bool IndexFactory::InitIndex(bool b_train)
{
    AppFieldContext::AppFileldTableIterator iter1 = 
                p_app_field_table_->begin();
    
    for ( ; iter1 != p_app_field_table_->end(); ++iter1) {
        AppFieldContext::FieldIdVetIterator iter2 = iter1->second.begin();

        for ( ; iter2 != iter1->second.end(); ++iter2) {
            std::vector<std::string>* p_index_type_str = 
                        &(iter2->second.index_type_str_array_);
            
            std::vector<std::string>* p_index_dir_str = 
                        &(iter2->second.index_dir_array_);

            size_t i_min_size = MIN(p_index_type_str->size(),
                                p_index_dir_str->size());
            
            for (size_t i = 0; i < i_min_size; ++i) {
                log_debug("appid:%d,fieldid:%d,idxTypePos:%d,dim:%d,modelType:%s,modelDir:%s",
                        iter1->first,
                        iter2->first,
                        (int)i,
                        iter2->second.ui_dim_,
                        (*p_index_type_str)[i].c_str(),
                        (*p_index_dir_str)[i].c_str() );
                try {
                    bool b_ret = true;
                    if(!b_train) {
                        b_ret = init_single_index(
                                iter1->first,
                                iter2->first,
                                i,
                                iter2->second.ui_dim_,
                                (*p_index_type_str)[i],
                                (*p_index_dir_str)[i]);
                        
                    } else {
                        b_ret = init_index_for_train(
                                iter1->first,
                                iter2->first,
                                i,
                                iter2->second.ui_dim_,
                                (*p_index_type_str)[i]);
                    }
                    if (!b_ret) {
                        return false;
                    }
                }
                catch(const std::regex_error& e) {
                    log_error("error:%d",e.code());
                    return false;
                }
                catch(const faiss::FaissException& e) {
                    log_error("error:%s",e.what());
                    return false;
                }
                catch(...) {
                    log_error("uknow exception error");
                    return false;
                }
            }
            idx_num_map_.insert(std::make_pair(
                UniqueIndexFlag(iter1->first,
                                iter2->first,
                                0),
                            i_min_size));
        }
    }

    return true;
}

/**
 * Reports whether a serving index has been registered for a key.
 *
 * @param o_unique_id  (appid, fieldid, index type position) key, as built
 *                     by init_single_index().
 * @return true if index_map_ contains the key, false otherwise.
 */
bool IndexFactory::CheckUniIdxIsValid(
    const UniqueIndexFlag& o_unique_id)
{
    return index_map_.find(o_unique_id) != index_map_.end();
}

/**
 * Returns the number of indexes recorded for a field.
 *
 * InitIndex() stores these counts with a third key component of 0.
 *
 * @param o_unique_id  key to look up in idx_num_map_.
 * @return the stored count, or -1 if nothing is recorded for the key.
 */
int64_t IndexFactory::GetIdxNum(
    const UniqueIndexFlag& o_unique_id)
{
    // Reuse the iterator from find(); a second lookup via operator[] is
    // redundant and would need a default-constructible mapped type.
    const auto it = idx_num_map_.find(o_unique_id);
    if (it == idx_num_map_.end()) {
        return -1;
    }
    return it->second;
}

bool IndexFactory::init_single_index(
    int i_appid ,
    int i_fieldid ,
    int i_index_typeid,
    uint16_t ui_dim,
    const std::string& s_single_index,
    const std::string& s_single_dir)
{
    if (std::regex_match(
            s_single_index ,
            std::regex("^IVF([0-9]+),Flat$"))) {
        log_debug("here is IVFx,Flat init");
        faiss::IndexIVFFlat* p_index_ivfflat = 
            dynamic_cast<faiss::IndexIVFFlat*>(
                    read_index_model_file(s_single_dir));
        if (NULL == p_index_ivfflat) {
            log_error("ivfflat is not match with model file");
            return false;
        }

        index_map_.insert(std::make_pair(
            UniqueIndexFlag(i_appid,
                            i_fieldid,
                            i_index_typeid),
            new IndexIVFFlat(p_index_ivfflat)));
    } else if (std::regex_match(
                s_single_index ,
                std::regex("^IDMap,HNSW([0-9]+),Flat$"))) {
        log_debug("here is IDMap,HNSWx,Flat init");
        faiss::Index* p_index_hnswflat = 
                faiss::index_factory(ui_dim , s_single_index.c_str());
        if (NULL == p_index_hnswflat) {
            log_error("hnswflat str is not match with faiss index_factory");
            return false;
        }

        index_map_.insert(std::make_pair(
            UniqueIndexFlag(i_appid,
                            i_fieldid,
                            i_index_typeid),
            new FaissAptBase(p_index_hnswflat , false)));
    } else if (std::regex_match(
                s_single_index ,
                std::regex("^IVF([0-9]+),PQ([0-9]+)$"))) {
        faiss::IndexIVFPQ* p_index_ivfpq = 
            dynamic_cast<faiss::IndexIVFPQ*>(
                    read_index_model_file(s_single_dir));
        if (NULL == p_index_ivfpq) {
            log_error("ivfpq is not match with model file");
            return false;
        }

        index_map_.insert(std::make_pair(
            UniqueIndexFlag(i_appid,
                            i_fieldid,
                            i_index_typeid),
            new IndexIVFPQ(p_index_ivfpq)));
    } else {
        log_error("index type regex error,please check app config");
        return false;
    }

    return true;
}

/**
 * Creates an index for training one (appid, fieldid, position) slot and
 * registers it in train_idx_map_.
 *
 * @param ui_dim          vector dimension passed to faiss::index_factory.
 * @param s_single_index  faiss index_factory description string.
 * @return false if faiss::index_factory returned NULL, true otherwise.
 *         Exceptions thrown by faiss::index_factory are left to the caller.
 */
bool IndexFactory::init_index_for_train(
    int i_appid ,
    int i_fieldid ,
    int i_index_typeid,
    uint16_t ui_dim ,
    const std::string& s_single_index)
{
    // Terminate each message with '\n' so consecutive messages do not run
    // together on stdout.
    printf("here is for model train\n");
    faiss::Index* p_index = faiss::index_factory(ui_dim , s_single_index.c_str());
    if (NULL == p_index) {
        printf(" %s is not match with faiss index_factory\n" , s_single_index.c_str());
        return false;
    }

    train_idx_map_.insert(std::make_pair(
            UniqueIndexFlag(i_appid, i_fieldid, i_index_typeid) , p_index) );

    return true;
}

/**
 * Loads a serialized faiss index from a file.
 *
 * @param s_single_dir  path handed to faiss::read_index.
 * @return the index returned by faiss::read_index; the caller is responsible
 *         for it. Returns NULL (after logging an error) if faiss::read_index
 *         returned NULL.
 * NOTE(review): faiss::read_index usually signals failure by throwing
 * faiss::FaissException, so the NULL branch may be unreachable in practice.
 */
faiss::Index* IndexFactory::read_index_model_file(
    const std::string& s_single_dir)
{
    if (faiss::Index* p_loaded = faiss::read_index(s_single_dir.c_str())) {
        return p_loaded;
    }
    log_error("fread faissindex init error");
    return NULL;
}

VET_INDEX_NAMESPACE_END
